# -*- coding: utf-8 -*-
# @Author   : Paul
# @File : mysql_common_source.py
# @DateTime : 2021/8/4 15:35
# @Description  :
from core.data_source.data_source import DataSource
import collections
import sys
import time
import pymysql as pymysql


class MysqlCommonDataSource(DataSource):
    """
    MySQL data source backed by one pymysql connection and one shared cursor.

    Statement failures are printed and the exception is stored in
    ``self.error``; the query helpers report failure by returning False
    rather than raising.
    """

    def __init__(self, host, port, user, password, database):
        """
        Open a MySQL connection, retrying up to three times after the first failure.

        Each retry is preceded by a 60 second wait. If the initial attempt and
        all three retries fail, the process exits with a non-zero status.

        :param host: server host / IP
        :param port: server port
        :param user: user name
        :param password: password
        :param database: database (schema) name
        """
        super(MysqlCommonDataSource, self).__init__()
        self.error = ''
        index = 0
        while index <= 3:
            try:
                self.conn = pymysql.connect(host=host, port=port, user=user, password=password, database=database)
                self.cur = self.conn.cursor()
                break
            except pymysql.Error as e:
                print(e)
                print('Error:Mysql Database connection failed')
                index += 1
                if index > 3:
                    print('第三次重试失败，退出程序！')
                    # Exit non-zero: a bare sys.exit() reports success (status 0)
                    # to the calling shell/scheduler even though we failed.
                    sys.exit(1)
                print('等待1分钟后，开始第%d次连接重试' % index)
                time.sleep(60)

    @staticmethod
    def _escape_percent(text):
        """Double every '%' so text survives pymysql's %-style parameter formatting."""
        return text.replace('%', '%%')

    def query(self, sql, params=None):
        """
        Execute a SQL statement on the shared cursor.

        :param sql: SQL text. When ``params`` is given it may contain pymysql
            ``%s`` placeholders, and any literal ``%`` must be written as ``%%``.
        :param params: optional sequence/mapping of values bound by pymysql.
            The default None executes ``sql`` verbatim, with no ``%`` formatting.
        :return: True on success; False on failure (the exception is printed
            and stored in ``self.error``).
        """
        try:
            self.cur.execute(sql, params)
            return True
        except Exception as e:
            print("Mysql Error:%s" % (e.__str__()))
            self.error = e
            return False

    def queryOne(self, sql):
        """
        Execute ``sql`` and return a single row.

        :param sql: sql
        :return: the row from ``cursor.fetchone()``; True if executing succeeded
            but fetching raised (the exception is stored in ``self.error``);
            False if the statement failed.
        """
        res = self.query(sql)
        if res:
            try:
                result = self.cur.fetchone()
                return result
            except Exception as e:
                # Keep the historical True return, but do not lose the cause.
                self.error = e
                return True
        else:
            print("SQL execution failed! No results returned")
            return False

    def queryAll(self, sql):
        """
        Execute ``sql`` and return every row.

        :param sql: sql
        :return: the rows from ``cursor.fetchall()``; True if executing succeeded
            but fetching raised (the exception is stored in ``self.error``);
            False if the statement failed.
        """
        res = self.query(sql)
        if res:
            try:
                result = self.cur.fetchall()
                return result
            except Exception as e:
                # Keep the historical True return, but do not lose the cause.
                self.error = e
                return True
        else:
            return False

    def insert(self, table_name, datadict):
        """
        Insert one row built from ``datadict`` and commit.

        Values are passed to pymysql as bound parameters instead of being
        pasted into the SQL text. Quotes inside values therefore cannot break
        the statement or inject SQL, and Python None is stored as NULL.
        Table and column names are still interpolated verbatim and must come
        from trusted code.

        :param table_name: table name
        :param datadict: column -> value mapping, e.g. {"id": 1}
        :return: ``cursor.rowcount`` after the statement (affected rows)
        """
        columns = list(datadict.keys())
        placeholders = ','.join(['%s'] * len(columns))
        insert_sql = ("INSERT INTO " + self._escape_percent(table_name)
                      + " (" + self._escape_percent(','.join(columns)) + ")"
                      + " VALUES (" + placeholders + ")")
        self.query(insert_sql, tuple(datadict[column] for column in columns))
        self.conn.commit()
        return self.cur.rowcount

    def update(self, table_name, datadict, where, where_params=None):
        """
        Update rows matching ``where`` with the values in ``datadict`` and commit.

        SET values are passed to pymysql as bound parameters, so quotes inside
        values cannot break the statement and Python None is stored as NULL.

        :param table_name: table name
        :param datadict: column -> new value mapping, e.g. {"id": 1}
        :param where: raw SQL filter condition (trusted text). When
            ``where_params`` is None it is used verbatim; any ``%`` in it is
            escaped automatically.
        :param where_params: optional sequence of values for ``%s``
            placeholders in ``where``. When given, ``where`` must use pymysql
            syntax (``%s`` placeholders, literal ``%`` written as ``%%``).
        :return: ``cursor.rowcount`` after the statement (affected rows)
        """
        columns = list(datadict.keys())
        assignments = ','.join(self._escape_percent(column) + "=%s" for column in columns)
        params = tuple(datadict[column] for column in columns)
        if where_params is None:
            # The statement is %-formatted because of the SET parameters, so a
            # literal '%' in where (e.g. LIKE 'a%') must be escaped.
            where_sql = self._escape_percent(where)
        else:
            where_sql = where
            params += tuple(where_params)
        update_sql = ("UPDATE " + self._escape_percent(table_name)
                      + " SET " + assignments + " WHERE " + where_sql)

        self.query(update_sql, params)
        self.conn.commit()
        return self.cur.rowcount

    def queryOrderedDict(self, sql):
        """
        Execute ``sql`` and return each row as an OrderedDict keyed by column name.

        SQL NULL becomes ''. Every other value is converted with ``str()`` and
        has its double quotes replaced by single quotes.

        :param sql: sql
        :return: list of OrderedDict rows; an empty list if the query failed
        """
        result = self.queryAll(sql)
        # queryAll signals failure with False (statement failed) or True
        # (fetch raised); neither is iterable, so there are no rows.
        if isinstance(result, bool):
            return []

        desc = self.cur.description
        rows = []
        for row in result:
            record = collections.OrderedDict()
            for column, value in zip(desc, row):
                # Test the value itself: str(value) == 'None' would also blank
                # genuine 'None' strings stored in the table.
                record[column[0]] = '' if value is None else str(value).replace("\"", "'")
            rows.append(record)
        return rows

    def commit(self):
        """Commit the current transaction on the connection."""
        self.conn.commit()

    def close(self):
        """Close the shared cursor, then the connection (even if closing the cursor fails)."""
        try:
            self.cur.close()
        finally:
            self.conn.close()
